{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cross validation and models ensemble to achieve better test metrics\n",
    "\n",
    "Cross validation and models ensemble are popular strategies in machine learning and deep learning areas to achieve more accurate and more stable outputs.  \n",
    "A typical practice is:\n",
    "* Split all the training dataset into K folds.\n",
    "* Train K models with every K-1 folds data (cross validation).\n",
    "* Execute inference on the test data with all the K models.\n",
    "* Compute the average values with weights or vote the most common value as the final result (ensemble).\n",
    "<p>\n",
    "<img src=\"../figures/models_ensemble.png\" width=\"80%\" alt='models_ensemble'>\n",
    "</p>\n",
    "\n",
    "MONAI provides `CrossValidation` dataset to automatically split K-folds for cross validation.\n",
    "\n",
    "And also provides `EnsembleEvaluator` and `MeanEnsemble`, `VoteEnsemble` post transforms for models ensemble.\n",
    "\n",
    "This tutorial shows how to leverage cross validation and ensemble modules in MONAI to set up a program.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/cross_validation_models_ensemble.ipynb)"
   ]
  },
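  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fold bookkeeping behind the steps above can be sketched in a few lines of NumPy. This is only an illustration, not how `CrossValidation` is implemented: `kfold_indices` is a hypothetical helper, and the counts (50 items, 5 folds) mirror the setup used later in this tutorial.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def kfold_indices(n_items, n_folds, seed=0):\n",
    "    # Hypothetical helper: shuffle the item indices once,\n",
    "    # then cut them into n_folds nearly equal parts.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    return np.array_split(rng.permutation(n_items), n_folds)\n",
    "\n",
    "\n",
    "folds = kfold_indices(n_items=50, n_folds=5)\n",
    "for k, val_idx in enumerate(folds):\n",
    "    # Model k trains on every fold except fold k, which is held out for validation.\n",
    "    train_idx = np.concatenate([f for j, f in enumerate(folds) if j != k])\n",
    "    print(k, len(train_idx), len(val_idx))\n",
    "```\n",
    "\n",
    "Each of the 5 iterations pairs 40 training indices with 10 held-out validation indices, and every item is held out exactly once."
   ]
  },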
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite, nibabel, tqdm]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Copyright 2020 MONAI Consortium\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "\n",
    "from abc import ABC, abstractmethod\n",
    "import logging\n",
    "import os\n",
    "import tempfile\n",
    "import shutil\n",
    "import sys\n",
    "\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from monai.apps import CrossValidation\n",
    "from monai.config import print_config\n",
    "from monai.data import CacheDataset, DataLoader, create_test_image_3d\n",
    "from monai.engines import (\n",
    "    EnsembleEvaluator,\n",
    "    SupervisedEvaluator,\n",
    "    SupervisedTrainer\n",
    ")\n",
    "from monai.handlers import MeanDice, StatsHandler, ValidationHandler, from_engine\n",
    "from monai.inferers import SimpleInferer, SlidingWindowInferer\n",
    "from monai.losses import DiceLoss\n",
    "from monai.networks.nets import UNet\n",
    "from monai.transforms import (\n",
    "    Activationsd,\n",
    "    AsChannelFirstd,\n",
    "    AsDiscreted,\n",
    "    Compose,\n",
    "    LoadImaged,\n",
    "    MeanEnsembled,\n",
    "    RandCropByPosNegLabeld,\n",
    "    RandRotate90d,\n",
    "    ScaleIntensityd,\n",
    "    EnsureTyped,\n",
    "    VoteEnsembled,\n",
    ")\n",
    "from monai.utils import set_determinism\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/data/medical\n"
     ]
    }
   ],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set determinism, logging, device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_determinism(seed=0)\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "device = torch.device(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate random (image, label) pairs\n",
    "\n",
    "Generate 60 pairs for the task, 50 for training and 10 for test.  \n",
    "And then split the 50 pairs into 5 folds to train 5 separate models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = os.path.join(root_dir, \"runs\")\n",
    "datalist = []\n",
    "\n",
    "if not os.path.exists(data_dir):\n",
    "    os.makedirs(data_dir)\n",
    "for i in range(60):\n",
    "    im, seg = create_test_image_3d(\n",
    "        128, 128, 128, num_seg_classes=1, channel_dim=-1)\n",
    "\n",
    "    n = nib.Nifti1Image(im, np.eye(4))\n",
    "    img_path = os.path.join(data_dir, f\"img{i}.nii.gz\")\n",
    "    nib.save(n, img_path)\n",
    "\n",
    "    n = nib.Nifti1Image(seg, np.eye(4))\n",
    "    seg_path = os.path.join(data_dir, f\"seg{i}.nii.gz\")\n",
    "    nib.save(n, seg_path)\n",
    "\n",
    "    datalist.append({\"image\": img_path, \"label\": seg_path})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the base class for cross validation datasets\n",
    "Here we leverage the `monai.apps.CrossValidation` to automatically split the datasets for cross validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "@abstractmethod\n",
    "class CVDataset(ABC, CacheDataset):\n",
    "    \"\"\"\n",
    "    Base class to generate cross validation datasets.\n",
    "\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        data,\n",
    "        transform,\n",
    "        cache_num=sys.maxsize,\n",
    "        cache_rate=1.0,\n",
    "        num_workers=4,\n",
    "    ) -> None:\n",
    "        data = self._split_datalist(datalist=data)\n",
    "        CacheDataset.__init__(\n",
    "            self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers\n",
    "        )\n",
    "\n",
    "    def _split_datalist(self, datalist):\n",
    "        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup transforms for training and validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_transforms = Compose(\n",
    "    [\n",
    "        LoadImaged(keys=[\"image\", \"label\"]),\n",
    "        AsChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n",
    "        ScaleIntensityd(keys=[\"image\", \"label\"]),\n",
    "        RandCropByPosNegLabeld(\n",
    "            keys=[\"image\", \"label\"],\n",
    "            label_key=\"label\",\n",
    "            spatial_size=[96, 96, 96],\n",
    "            pos=1,\n",
    "            neg=1,\n",
    "            num_samples=4,\n",
    "        ),\n",
    "        RandRotate90d(keys=[\"image\", \"label\"], prob=0.5, spatial_axes=[0, 2]),\n",
    "        EnsureTyped(keys=[\"image\", \"label\"]),\n",
    "    ]\n",
    ")\n",
    "val_transforms = Compose(\n",
    "    [\n",
    "        LoadImaged(keys=[\"image\", \"label\"]),\n",
    "        AsChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n",
    "        ScaleIntensityd(keys=[\"image\", \"label\"]),\n",
    "        EnsureTyped(keys=[\"image\", \"label\"]),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Datasets and DataLoaders for train, validation and test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "num = 5\n",
    "folds = list(range(num))\n",
    "\n",
    "cvdataset = CrossValidation(\n",
    "    dataset_cls=CVDataset,\n",
    "    data=datalist[0: 50],\n",
    "    nfolds=5,\n",
    "    seed=12345,\n",
    "    transform=train_transforms,\n",
    ")\n",
    "\n",
    "train_dss = [cvdataset.get_dataset(folds=folds[0: i] + folds[(i + 1):]) for i in folds]\n",
    "val_dss = [cvdataset.get_dataset(folds=i, transform=val_transforms) for i in range(num)]\n",
    "\n",
    "train_loaders = [DataLoader(train_dss[i], batch_size=2, shuffle=True, num_workers=4) for i in folds]\n",
    "val_loaders = [DataLoader(val_dss[i], batch_size=1, num_workers=4) for i in folds]\n",
    "\n",
    "test_ds = CacheDataset(data=datalist[50:], transform=val_transforms)\n",
    "test_loader = DataLoader(test_ds, batch_size=1, num_workers=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a training process based on workflows\n",
    "\n",
    "More usage examples of MONAI workflows are available at: [workflow examples](https://github.com/Project-MONAI/tutorials/tree/master/modules/engines)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def train(index):\n",
    "    net = UNet(\n",
    "        spatial_dims=3,\n",
    "        in_channels=1,\n",
    "        out_channels=1,\n",
    "        channels=(16, 32, 64, 128, 256),\n",
    "        strides=(2, 2, 2, 2),\n",
    "        num_res_units=2,\n",
    "    ).to(device)\n",
    "    loss = DiceLoss(sigmoid=True)\n",
    "    opt = torch.optim.Adam(net.parameters(), 1e-3)\n",
    "\n",
    "    val_post_transforms = Compose(\n",
    "        [EnsureTyped(keys=\"pred\"), Activationsd(keys=\"pred\", sigmoid=True), AsDiscreted(\n",
    "            keys=\"pred\", threshold_values=True)]\n",
    "    )\n",
    "\n",
    "    evaluator = SupervisedEvaluator(\n",
    "        device=device,\n",
    "        val_data_loader=val_loaders[index],\n",
    "        network=net,\n",
    "        inferer=SlidingWindowInferer(\n",
    "            roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),\n",
    "        postprocessing=val_post_transforms,\n",
    "        key_val_metric={\n",
    "            \"val_mean_dice\": MeanDice(\n",
    "                include_background=True,\n",
    "                output_transform=from_engine([\"pred\", \"label\"]),\n",
    "            )\n",
    "        },\n",
    "    )\n",
    "    train_handlers = [\n",
    "        ValidationHandler(validator=evaluator, interval=4, epoch_level=True),\n",
    "        StatsHandler(tag_name=\"train_loss\",\n",
    "                     output_transform=from_engine([\"loss\"], first=True)),\n",
    "    ]\n",
    "\n",
    "    trainer = SupervisedTrainer(\n",
    "        device=device,\n",
    "        max_epochs=4,\n",
    "        train_data_loader=train_loaders[index],\n",
    "        network=net,\n",
    "        optimizer=opt,\n",
    "        loss_function=loss,\n",
    "        inferer=SimpleInferer(),\n",
    "        amp=False,\n",
    "        train_handlers=train_handlers,\n",
    "    )\n",
    "    trainer.run()\n",
    "    return net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execute 5 training processes and get 5 models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    },
    "tags": [
     "outputPrepend"
    ]
   },
   "outputs": [],
   "source": [
    "models = [train(i) for i in range(num)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define evaluation process based on `EnsembleEvaluator`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "def ensemble_evaluate(post_transforms, models):\n",
    "    evaluator = EnsembleEvaluator(\n",
    "        device=device,\n",
    "        val_data_loader=test_loader,\n",
    "        pred_keys=[\"pred0\", \"pred1\", \"pred2\", \"pred3\", \"pred4\"],\n",
    "        networks=models,\n",
    "        inferer=SlidingWindowInferer(\n",
    "            roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),\n",
    "        postprocessing=post_transforms,\n",
    "        key_val_metric={\n",
    "            \"test_mean_dice\": MeanDice(\n",
    "                include_background=True,\n",
    "                output_transform=from_engine([\"pred\", \"label\"]),\n",
    "            )\n",
    "        },\n",
    "    )\n",
    "    evaluator.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the ensemble result with `MeanEnsemble`\n",
    "\n",
    "`EnsembleEvaluator` accepts a list of models for inference and outputs a list of predictions for further operations.\n",
    "\n",
    "Here the input data is a list or tuple of PyTorch Tensor with shape: [B, C, H, W, D].  \n",
    "The list represents the output data from 5 models.  \n",
    "And `MeanEnsemble` also can support to add `weights` for the input data:\n",
    "* The `weights` will be added to input data from highest dimension.\n",
    "* If the `weights` only has 1 dimension, it will be added to the `E` dimension of input data.\n",
    "* If the `weights` has 3 dimensions, it will be added to `E`, `B` and `C` dimensions.  \n",
    "For example, to ensemble 3 segmentation model outputs, every output has 4 channels(classes),  \n",
    "The input data shape can be: [3, B, 4, H, W, D], and add different `weights` for different classes.  \n",
    "So the `weights` shape can be: [3, 1, 4], like:  \n",
    "`weights = [[[1, 2, 3, 4]], [[4, 3, 2, 1]], [[1, 1, 1, 1]]]`."
   ]
  },
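  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The weight broadcasting described above can be illustrated with a small NumPy sketch. It is not MONAI's implementation: `weighted_mean_ensemble` is a hypothetical helper, and dividing by the sum of the weights is an assumption that follows the usual definition of a weighted average.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def weighted_mean_ensemble(preds, weights=None):\n",
    "    # Stack E predictions of shape [B, C, ...] into one array of shape [E, B, C, ...].\n",
    "    stacked = np.stack(preds, axis=0)\n",
    "    if weights is None:\n",
    "        return stacked.mean(axis=0)\n",
    "    w = np.asarray(weights, dtype=stacked.dtype)\n",
    "    # Pad trailing singleton axes so the weights align with the leading dimensions.\n",
    "    w = w.reshape(w.shape + (1,) * (stacked.ndim - w.ndim))\n",
    "    # Weighted average over the ensemble dimension E.\n",
    "    return (stacked * w).sum(axis=0) / w.sum(axis=0)\n",
    "\n",
    "\n",
    "# 3 constant model outputs (values 1, 2 and 3), each with shape [B=1, C=4, 2, 2, 2]\n",
    "preds = [np.full((1, 4, 2, 2, 2), v, dtype=np.float32) for v in (1.0, 2.0, 3.0)]\n",
    "weights = [[[1, 2, 3, 4]], [[4, 3, 2, 1]], [[1, 1, 1, 1]]]  # shape [3, 1, 4]\n",
    "out = weighted_mean_ensemble(preds, weights)\n",
    "print(out.shape)\n",
    "print(out[0, :, 0, 0, 0])\n",
    "```\n",
    "\n",
    "The result keeps the shape of a single prediction, and each class channel gets its own mix of the three models: channel 0 leans towards the second model, channel 3 towards the first."
   ]
  },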
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "pycharm": {
     "is_executing": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:ignite.engine.engine.EnsembleEvaluator:Engine run resuming from iteration 0, epoch 0 until 1 epochs\n",
      "INFO:ignite.engine.engine.EnsembleEvaluator:Got new best metric of test_mean_dice: 0.9329929351806641\n",
      "INFO:ignite.engine.engine.EnsembleEvaluator:Epoch[1] Complete. Time taken: 00:00:15\n",
      "INFO:ignite.engine.engine.EnsembleEvaluator:Engine run complete. Time taken: 00:00:15\n"
     ]
    }
   ],
   "source": [
    "mean_post_transforms = Compose(\n",
    "    [\n",
    "        EnsureTyped(keys=[\"pred0\", \"pred1\", \"pred2\", \"pred3\", \"pred4\"]),\n",
    "        MeanEnsembled(\n",
    "            keys=[\"pred0\", \"pred1\", \"pred2\", \"pred3\", \"pred4\"],\n",
    "            output_key=\"pred\",\n",
    "            # in this particular example, we use validation metrics as weights\n",
    "            weights=[0.95, 0.94, 0.95, 0.94, 0.90],\n",
    "        ),\n",
    "        Activationsd(keys=\"pred\", sigmoid=True),\n",
    "        AsDiscreted(keys=\"pred\", threshold_values=True),\n",
    "    ]\n",
    ")\n",
    "ensemble_evaluate(mean_post_transforms, models)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the ensemble result with `VoteEnsemble`\n",
    "\n",
    "Here the input data is a list or tuple of PyTorch Tensor with shape: [B, C, H, W, D].  \n",
    "The list represents the output data from 5 models.\n",
    "\n",
    "Note that:\n",
    "* `VoteEnsemble` expects the input data is discrete values.\n",
    "* Input data can be multiple channels data in One-Hot format or single channel data.\n",
    "* It will vote to select the most common data between items.\n",
    "* The output data has the same shape as every item of the input data."
   ]
  },
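  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the voting step, the NumPy sketch below takes a majority vote over binary (or One-Hot) predictions. It is not MONAI's implementation: `vote_ensemble` is a hypothetical helper, it does not cover single-channel inputs holding more than two class indices, and resolving ties to 0 is an assumption of this sketch.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def vote_ensemble(preds):\n",
    "    # Stack E binary predictions into shape [E, ...] and average the votes per element.\n",
    "    stacked = np.stack(preds, axis=0).astype(np.float32)\n",
    "    # An element is set only when more than half of the models vote for it.\n",
    "    return (stacked.mean(axis=0) > 0.5).astype(np.float32)\n",
    "\n",
    "\n",
    "preds = [\n",
    "    np.array([[1, 0, 1, 0]]),\n",
    "    np.array([[1, 1, 0, 0]]),\n",
    "    np.array([[1, 0, 0, 1]]),\n",
    "]\n",
    "print(vote_ensemble(preds))\n",
    "```\n",
    "\n",
    "Only the first element is selected by a majority of the three toy predictions, so it is the only one set in the output, which has the same shape as each input item."
   ]
  },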
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "pycharm": {
     "is_executing": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:ignite.engine.engine.EnsembleEvaluator:Engine run resuming from iteration 0, epoch 0 until 1 epochs\n",
      "INFO:ignite.engine.engine.EnsembleEvaluator:Got new best metric of test_mean_dice: 0.9344435930252075\n",
      "INFO:ignite.engine.engine.EnsembleEvaluator:Epoch[1] Complete. Time taken: 00:00:15\n",
      "INFO:ignite.engine.engine.EnsembleEvaluator:Engine run complete. Time taken: 00:00:16\n"
     ]
    }
   ],
   "source": [
    "vote_post_transforms = Compose(\n",
    "    [\n",
    "        EnsureTyped(keys=[\"pred0\", \"pred1\", \"pred2\", \"pred3\", \"pred4\"]),\n",
    "        Activationsd(keys=[\"pred0\", \"pred1\", \"pred2\",\n",
    "                           \"pred3\", \"pred4\"], sigmoid=True),\n",
    "        # transform data into discrete before voting\n",
    "        AsDiscreted(keys=[\"pred0\", \"pred1\", \"pred2\", \"pred3\",\n",
    "                          \"pred4\"], threshold_values=True),\n",
    "        VoteEnsembled(keys=[\"pred0\", \"pred1\", \"pred2\",\n",
    "                            \"pred3\", \"pred4\"], output_key=\"pred\"),\n",
    "    ]\n",
    ")\n",
    "ensemble_evaluate(vote_post_transforms, models)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove directory if a temporary was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
